import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from pdb import set_trace
from matplotlib.ticker import MaxNLocator

# Load the results CSV. Earlier result files, kept for reference:
# df_name = 'all_runs_data_v2.csv'  # used for all tables in the paper
# df_name = 'all_runs_data_4.csv'   # used for the table that includes the upper limit metric
# The __main__ block below also uses df_name to pick the output folder for saved plots.
df_name = 'all_runs_data_5.csv'   # in these results, bidder 1 has a multiplier of 2, and bidder 0 of 1.

df = pd.read_csv(df_name)
# Metric (a column name of the loaded CSV) to plot. Options used so far:
#   'total advertiser participating value gain', 'total advertiser utility gain zero bid offset',
#   'total payment zero bid offset', 'total payment no offset',
#   'sequence log probability', 'reference LLM log probability',
#   'advertiser 0 expected value', 'advertiser 1 expected value'
metric = 'total payment zero bid offset'

# Human-readable y-axis labels; any metric not listed is shown under its raw column name.
METRIC_PLOT_NAMES = {
    'total advertiser participating value gain': 'Total Advertiser Reward Gain',
    'total payment zero bid offset': 'Revenue',
    'total advertiser utility gain zero bid offset': 'Total Advertiser Utility Gain',
    'sequence log probability': 'Reply Log Probability',
    'total payment no offset': 'Revenue',
    # Raw string: '\p' is an invalid escape sequence in a regular string literal.
    'reference LLM log probability': r'Reply Log Probability wrt $ \pi_{ref} $',
}

# Plot hyperparameters
metric_plot_name = METRIC_PLOT_NAMES.get(metric, metric)

upper_limit_metric_name = 'reference LLM log probability'
font_size = 25
tick_size = 18
save_plot = True
no_gain_in_name = False  # set to True to remove the word 'Gain' from the metric name on the y-axis
# The log-probability metrics additionally get a dashed reference line from upper_limit_metric_name.
add_upper_limit = metric in ('sequence log probability', 'reference LLM log probability')


def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, lower, upper)`` of a two-sided Student-t confidence interval.

    Args:
        data: 1-D sequence of numeric observations.
        confidence: coverage of the interval (default 0.95).

    Returns:
        Tuple of the sample mean and the lower/upper interval bounds. The half-width is
        ``scipy.stats.sem(data)`` times the t quantile with ``len(data) - 1`` degrees of freedom.
    """
    center = np.mean(data)
    dof = len(data) - 1
    # Two-sided interval: use the upper (1 + confidence) / 2 quantile of the t distribution.
    t_critical = stats.t.ppf((1 + confidence) / 2., dof)
    half_width = stats.sem(data) * t_critical
    return center, center - half_width, center + half_width

def unpack_confidence_intervals(confidence_intervals):
    """Split a Series of ``(mean, lower, upper)`` tuples into three Series for plotting.

    Each returned Series keeps the index of ``confidence_intervals``.
    """
    def _component(position):
        # Pull one element out of every tuple, preserving the original index.
        return confidence_intervals.apply(lambda triple: triple[position])

    return _component(0), _component(1), _component(2)


if __name__ == '__main__':
    import os  # used only to create the output folder before saving the figure

    def _ci_by_samples_used(frame, column):
        """Group `frame` by 'samples used'; return (means, lower, upper) Series of `column`."""
        grouped_ci = frame.groupby('samples used')[column].apply(mean_confidence_interval)
        return unpack_confidence_intervals(grouped_ci)

    # Split runs by whether input expansion was used. `.eq(True)` / `.eq(False)` are the same
    # element-wise comparisons as `== True` / `== False`.
    df_with_expansion = df[df['use_input_expansion'].eq(True)]
    df_without_expansion = df[df['use_input_expansion'].eq(False)]

    # Mean and confidence interval (default 95%) of the chosen metric per number of samples used
    means_with, lower_with, upper_with = _ci_by_samples_used(df_with_expansion, metric)
    means_without, lower_without, upper_without = _ci_by_samples_used(df_without_expansion, metric)

    if add_upper_limit:
        # The upper-limit metric is the reference LLM log probability, which generated the
        # sequences only in the no-expansion runs.
        ul_means, ul_lower, ul_upper = _ci_by_samples_used(df_without_expansion, upper_limit_metric_name)
        # Draw it as a flat band: take the entry whose 'samples used' label is 1 (label-based
        # lookup, not positional) and repeat it for every x position.
        n_points = len(ul_means)
        means_without_upper_limit = np.full(n_points, ul_means.loc[1])
        lower_without_upper_limit = np.full(n_points, ul_lower.loc[1])
        upper_without_upper_limit = np.full(n_points, ul_upper.loc[1])

    plt.figure(figsize=(12, 8))

    # Plot for with expansion
    plt.plot(means_with.index, means_with, label=r'Using $\pi_{\text{con}}(\cdot | x)$', color='orange')
    plt.fill_between(means_with.index, lower_with, upper_with, color='orange', alpha=0.2)

    # Plot for without expansion
    plt.plot(means_without.index, means_without, label=r'Using $\pi_{\text{ref}}(\cdot | x)$', color='blue')
    plt.fill_between(means_without.index, lower_without, upper_without, color='blue', alpha=0.2)

    # Plot for upper limit metric
    if add_upper_limit:
        if metric == 'reference LLM log probability':
            upper_limit_label = r'Single sample $\sim \hat{\pi}_{\text{ref}}(\cdot | x)$'
        else:
            upper_limit_label = r'Single sample $\sim \hat{\pi}_{\text{opt}}(\cdot | x)$'
        plt.plot(means_without.index, means_without_upper_limit, label=upper_limit_label, color='green', linestyle='--')
        plt.fill_between(means_without.index, lower_without_upper_limit, upper_without_upper_limit, color='green', alpha=0.2)

    plt.xlabel('Candidate Replies Generated', fontsize=font_size)
    # Optionally drop the word 'Gain' from the y-axis label (strip the leftover trailing space).
    if no_gain_in_name:
        metric_name_to_show = metric_plot_name.replace('Gain', '').strip()
    else:
        metric_name_to_show = metric_plot_name
    plt.ylabel(metric_name_to_show, fontsize=font_size)
    # No title: it is added in the paper.
    plt.legend(fontsize=font_size)
    plt.grid(True)

    # x-axis limits slightly beyond the data points for visibility.
    # NOTE: assumes 'samples used' ranges over 1..20.
    plt.xlim(0.5, 20.5)

    # Explicit tick locations so that every value, including 1 and 20, is shown
    plt.xticks(np.arange(1, 21, step=1))

    # Larger tick labels
    plt.tick_params(axis='both', which='major', labelsize=tick_size)

    if save_plot:
        file_format = 'pdf'
        if df_name == 'all_runs_data_5.csv':
            savefolder = 'plots_bidder_multipliers_1_2'
        else:
            savefolder = 'plots_bidder_multipliers_1_1'

        suffix = '_plot_with_upper_limit' if add_upper_limit else '_plot'
        savename = f'./{savefolder}/{metric}{suffix}.{file_format}'

        # Metric names contain spaces; use underscores in the file name
        savename = savename.replace(' ', '_')
        # savefig does not create missing directories
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        plt.savefig(savename)
    else:
        plt.show()
